import random
from glob import glob
import numpy as np
import torch
from matplotlib import image
from matplotlib import pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision import transforms
from torchvision.models import ResNet50_Weights, resnet50
# Visualize a random grid of training images from one Imagenette class folder.
# NOTE: in ImageNet/Imagenette, synset n02102040 is "English springer";
# the "Tench" class is n01440764. Change the folder below if Tench was intended.
IMAGENETTE_CLASS_GLOB = r"C:\Users\BanbhanAbdulBasit\OneDrive - Johannes Kepler Universität Linz\Studium\Master\1WS2022\XAI\UE\xai_resnet50\pics\imagenette2-320\train\n02102040\*.JPEG"
GRID_ROWS, GRID_COLS = 5, 5

# Despite the "_test" suffix, these paths come from the *train* split.
image_list_test = glob(IMAGENETTE_CLASS_GLOB)
if not image_list_test:
    # Fail with an actionable message instead of an opaque IndexError
    # from random.choice/random.sample on an empty list.
    raise FileNotFoundError(f"No images matched {IMAGENETTE_CLASS_GLOB!r}")

# Sample without replacement so every grid cell shows a distinct image; if the
# folder holds fewer images than cells, the leftover cells are left blank.
_n_cells = GRID_ROWS * GRID_COLS
_picks = random.sample(image_list_test, k=min(_n_cells, len(image_list_test)))

# squeeze=False keeps axarr a 2-D array even for a 1-row or 1-column grid.
f, axarr = plt.subplots(GRID_ROWS, GRID_COLS, figsize=(35, 35), squeeze=False)
for ax in axarr.flat:
    ax.axis("off")
for ax, path in zip(axarr.flat, _picks):
    ax.imshow(image.imread(path))
plt.show()